from cleanlab_tlm import TrustworthyRAG, get_default_evals

from lfx.custom import Component
from lfx.io import (
    BoolInput,
    DropdownInput,
    MessageTextInput,
    Output,
    SecretStrInput,
)
from lfx.schema.message import Message


class CleanlabRAGEvaluator(Component):
    """A component that evaluates the quality of RAG (Retrieval-Augmented Generation) outputs using Cleanlab.

    This component takes a query, retrieved context, and generated response from a RAG pipeline,
    and uses Cleanlab's evaluation algorithms to assess various aspects of the RAG system's performance.

    The component can evaluate:
    - Overall trustworthiness of the LLM generated response
    - Context sufficiency (whether the retrieved context contains information needed to answer the query)
    - Response groundedness (whether the response is supported directly by the context)
    - Response helpfulness (whether the response effectively addresses the user's query)
    - Query ease (whether the user query seems easy for an AI system to properly handle, useful to diagnose
      queries that are: complex, vague, tricky, or disgruntled-sounding)

    Outputs:
        - Trust Score: A score between 0-1 corresponding to the trustworthiness of the response. A higher score
          indicates a higher confidence that the response is correct/good.
        - Explanation: An LLM generated explanation of the trustworthiness assessment
        - Other Evals: Additional evaluation metrics for selected evaluation types in the "Controls" tab
        - Evaluation Summary: A comprehensive summary of context, query, response, and selected evaluation results

    This component works well in conjunction with the CleanlabRemediator to create a complete trust evaluation
    and remediation pipeline.

    More details on the evaluation metrics can be found here: https://help.cleanlab.ai/tlm/use-cases/tlm_rag/
    """

    display_name = "Cleanlab RAG Evaluator"
    description = "Evaluates context, query, and response from a RAG pipeline using Cleanlab and outputs trust metrics."
    icon = "Cleanlab"
    name = "CleanlabRAGEvaluator"

    inputs = [
        SecretStrInput(
            name="api_key",
            display_name="Cleanlab API Key",
            info="Your Cleanlab API key.",
            required=True,
        ),
        DropdownInput(
            name="model",
            display_name="Cleanlab Evaluation Model",
            options=[
                "gpt-4.1",
                "gpt-4.1-mini",
                "gpt-4.1-nano",
                "o4-mini",
                "o3",
                "gpt-4.5-preview",
                "gpt-4o-mini",
                "gpt-4o",
                "o3-mini",
                "o1",
                "o1-mini",
                "gpt-4",
                "gpt-3.5-turbo-16k",
                "claude-3.7-sonnet",
                "claude-3.5-sonnet-v2",
                "claude-3.5-sonnet",
                "claude-3.5-haiku",
                "claude-3-haiku",
                "nova-micro",
                "nova-lite",
                "nova-pro",
            ],
            info="The model Cleanlab uses to evaluate the context, query, and response. This does NOT need to be "
            "the same model that generated the response.",
            value="gpt-4o-mini",
            required=True,
            advanced=True,
        ),
        DropdownInput(
            name="quality_preset",
            display_name="Quality Preset",
            options=["base", "low", "medium"],
            value="medium",
            info="This determines the accuracy, latency, and cost of the evaluation. Higher quality is generally "
            "slower but more accurate.",
            required=True,
            advanced=True,
        ),
        MessageTextInput(
            name="context",
            display_name="Context",
            info="The context retrieved for the given query.",
            required=True,
        ),
        MessageTextInput(
            name="query",
            display_name="Query",
            info="The user's query.",
            required=True,
        ),
        MessageTextInput(
            name="response",
            display_name="Response",
            info="The response generated by the LLM.",
            required=True,
        ),
        BoolInput(
            name="run_context_sufficiency",
            display_name="Run Context Sufficiency",
            value=False,
            advanced=True,
        ),
        BoolInput(
            name="run_response_groundedness",
            display_name="Run Response Groundedness",
            value=False,
            advanced=True,
        ),
        BoolInput(
            name="run_response_helpfulness",
            display_name="Run Response Helpfulness",
            value=False,
            advanced=True,
        ),
        BoolInput(
            name="run_query_ease",
            display_name="Run Query Ease",
            value=False,
            advanced=True,
        ),
    ]

    outputs = [
        Output(display_name="Response", name="response_passthrough", method="pass_response", types=["Message"]),
        Output(display_name="Trust Score", name="trust_score", method="get_trust_score", types=["number"]),
        Output(display_name="Explanation", name="trust_explanation", method="get_trust_explanation", types=["Message"]),
        Output(display_name="Other Evals", name="other_scores", method="get_other_scores", types=["Data"]),
        Output(
            display_name="Evaluation Summary",
            name="evaluation_summary",
            method="get_evaluation_summary",
            types=["Message"],
        ),
    ]

    def _enabled_eval_names(self) -> list[str]:
        """Return the Cleanlab eval names enabled via the component's boolean toggles.

        Single source of truth for the toggle-to-eval-name mapping, shared by
        `_evaluate_once`, `get_other_scores`, and `get_evaluation_summary` so a new
        eval only needs to be registered here (plus its BoolInput).
        """
        flags = {
            "context_sufficiency": self.run_context_sufficiency,
            "response_groundedness": self.run_response_groundedness,
            "response_helpfulness": self.run_response_helpfulness,
            "query_ease": self.run_query_ease,
        }
        return [name for name, enabled in flags.items() if enabled]

    def _collect_other_scores(self, result: dict) -> dict:
        """Extract the scores of the enabled (non-trustworthiness) evals from a TrustworthyRAG result.

        Evals that were enabled but are absent from `result` (e.g. the evaluation
        failed and `result` is empty) are silently skipped.
        """
        return {name: result[name]["score"] for name in self._enabled_eval_names() if name in result}

    def _evaluate_once(self) -> dict:
        """Run the Cleanlab evaluation once and memoize the result on the instance.

        All output methods funnel through this so a single TrustworthyRAG.score()
        call serves every output port. On failure the error is surfaced via
        `self.status` and an empty dict is cached, letting downstream getters fall
        back to their defaults instead of raising.
        """
        if not hasattr(self, "_cached_result"):
            try:
                self.status = "Configuring selected evals..."
                enabled_names = self._enabled_eval_names()
                # Keep only the default evals the user toggled on.
                selected_evals = [e for e in get_default_evals() if e.name in enabled_names]

                validator = TrustworthyRAG(
                    api_key=self.api_key,
                    quality_preset=self.quality_preset,
                    # "log": ["explanation"] asks Cleanlab to include a natural-language
                    # rationale alongside the trustworthiness score.
                    options={"log": ["explanation"], "model": self.model},
                    evals=selected_evals,
                )

                self.status = f"Running evals: {[e.name for e in selected_evals]}"
                self._cached_result = validator.score(
                    query=self.query,
                    context=self.context,
                    response=self.response,
                )
                self.status = "Evaluation complete."

            except Exception as e:  # noqa: BLE001
                # Deliberate best-effort: a failed evaluation must not break the
                # flow, so report via status and cache an empty result.
                self.status = f"Evaluation failed: {e!s}"
                self._cached_result = {}
        return self._cached_result

    def pass_response(self) -> Message:
        """Pass the original LLM response through unchanged as a Message output."""
        self.status = "Passing through response."
        return Message(text=self.response)

    def get_trust_score(self) -> float:
        """Return the trustworthiness score (0-1), or 0.0 if evaluation failed."""
        score = self._evaluate_once().get("trustworthiness", {}).get("score", 0.0)
        self.status = f"Trust Score: {score:.3f}"
        return score

    def get_trust_explanation(self) -> Message:
        """Return the LLM-generated explanation of the trust score (empty text on failure)."""
        # `log` may be present but None in the API payload; `or {}` guards both cases.
        log = self._evaluate_once().get("trustworthiness", {}).get("log") or {}
        explanation = log.get("explanation", "")
        self.status = "Trust explanation extracted."
        return Message(text=explanation)

    def get_other_scores(self) -> dict:
        """Return a dict of scores for the enabled optional evals."""
        filtered_scores = self._collect_other_scores(self._evaluate_once())
        self.status = f"{len(filtered_scores)} other evals returned."
        return filtered_scores

    def get_evaluation_summary(self) -> Message:
        """Build a human-readable summary of the query, context, response, and all eval results."""
        result = self._evaluate_once()

        query_text = self.query.strip()
        context_text = self.context.strip()
        response_text = self.response.strip()

        trustworthiness = result.get("trustworthiness", {})
        trust = trustworthiness.get("score", 0.0)
        trust_exp = (trustworthiness.get("log") or {}).get("explanation", "")

        other_scores = self._collect_other_scores(result)

        metrics = f"Trustworthiness: {trust:.3f}"
        if trust_exp:
            metrics += f"\nExplanation: {trust_exp}"
        if other_scores:
            metrics += "\n" + "\n".join(f"{k.replace('_', ' ').title()}: {v:.3f}" for k, v in other_scores.items())

        summary = (
            f"Query:\n{query_text}\n"
            "-----\n"
            f"Context:\n{context_text}\n"
            "-----\n"
            f"Response:\n{response_text}\n"
            "------------------------------\n"
            f"{metrics}"
        )

        self.status = "Evaluation summary built."
        return Message(text=summary)
